# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from asyncrt_test_utils import create_test_device_context, expect_eq
from gpu.host import DeviceContext, Dim
from gpu.host._nvidia_cuda import (
    CUDA,
    CUcontext,
    CUDA_get_current_context,
)
from testing import TestSuite


fn _run_cuda_context(ctx: DeviceContext) raises:
    """Checks that `push_context()` makes the CUcontext of `ctx` current and
    that the previously current CUcontext is current again once the `with`
    block exits.

    Handles are compared by address (`Int(...)`) rather than by truthiness:
    `Bool(a) == Bool(b)` only checks that both handles are null or both are
    non-null, so it cannot detect that the wrong context is current.

    Args:
        ctx: The device context under test.

    Raises:
        If any handle comparison performed through `expect_eq` fails.
    """
    print("-")
    print("_run_cuda_context()")
    var initial_ctx: CUcontext = CUDA_get_current_context()
    var cuda_ctx: CUcontext = CUDA(ctx)

    with ctx.push_context() as cur_ctx:
        # cur_ctx must wrap the very same CUcontext and stream as `ctx`.
        expect_eq(Int(CUDA(ctx)), Int(CUDA(cur_ctx)))
        expect_eq(Int(CUDA(ctx.stream())), Int(CUDA(cur_ctx.stream())))
        # The current CUcontext must be exactly the one that was pushed.
        expect_eq(Int(cuda_ctx), Int(CUDA_get_current_context()))

    # Leaving the `with` block must restore whatever was current before.
    expect_eq(Int(initial_ctx), Int(CUDA_get_current_context()))
    print("initial CUcontext:", initial_ctx)
    print("CUcontext:", cuda_ctx)


fn _run_cuda_multi_context(ctx0: DeviceContext, ctx1: DeviceContext) raises:
    """Checks nested `push_context()` calls on two device contexts.

    Each push must make that context's CUcontext current. Leaving the inner
    block must restore the outer context, and leaving the outer block must
    restore the context that was current on entry.

    Handles are compared by address (`Int(...)`): `Bool(a) == Bool(b)` would
    only check that both handles are null or both non-null. That can never
    tell two valid, distinct device contexts apart.

    Args:
        ctx0: Outer device context (pushed first).
        ctx1: Inner device context (pushed while `ctx0` is current).

    Raises:
        If any handle comparison performed through `expect_eq` fails.
    """
    print("-")
    print("_run_cuda_multi_context()")
    var initial_ctx: CUcontext = CUDA_get_current_context()
    var cuda_ctx0: CUcontext = CUDA(ctx0)
    var cuda_ctx1: CUcontext = CUDA(ctx1)

    with ctx0.push_context() as cur_ctx0:
        # cur_ctx0 must wrap the same CUcontext and stream as `ctx0`.
        expect_eq(Int(CUDA(ctx0)), Int(CUDA(cur_ctx0)))
        expect_eq(Int(CUDA(ctx0.stream())), Int(CUDA(cur_ctx0.stream())))
        # The current CUcontext must be exactly the one that was pushed.
        expect_eq(Int(cuda_ctx0), Int(CUDA_get_current_context()))

        # Nested context pushes save, push and restore.
        with ctx1.push_context() as cur_ctx1:
            # cur_ctx1 must wrap the same CUcontext and stream as `ctx1`.
            expect_eq(Int(CUDA(ctx1)), Int(CUDA(cur_ctx1)))
            expect_eq(
                Int(CUDA(ctx1.stream())), Int(CUDA(cur_ctx1.stream()))
            )
            # The inner push must replace the outer context as current.
            expect_eq(Int(cuda_ctx1), Int(CUDA_get_current_context()))

        # Make sure that the previously pushed CUcontext has been restored.
        expect_eq(Int(cuda_ctx0), Int(CUDA_get_current_context()))

    # Leaving the outer block must restore whatever was current on entry.
    expect_eq(Int(initial_ctx), Int(CUDA_get_current_context()))
    print("initial CUcontext:", initial_ctx)
    print("CUcontext(id: 0):", cuda_ctx0)
    print("CUcontext(id: 1):", cuda_ctx1)


fn _run_cuda_stream(ctx: DeviceContext) raises:
    """Fetches the stream returned by `ctx.stream()`, synchronizes on it and
    prints the underlying CUstream handle.

    Args:
        ctx: The device context whose stream is exercised.

    Raises:
        If fetching or synchronizing the stream fails.
    """
    print("-")
    print("_run_cuda_stream()")

    print("Getting the stream.")
    var dev_stream = ctx.stream()

    print("Synchronizing on `stream`.")
    dev_stream.synchronize()
    print("CUstream: ", CUDA(dev_stream))


fn _run_cuda_external_function(ctx: DeviceContext) raises:
    """Loads a precompiled PTX `vectorAdd` kernel, launches it on two filled
    buffers and verifies that every output element equals 2.0 + 1.0.

    Args:
        ctx: The device context used to allocate buffers, load the kernel
            and launch it.

    Raises:
        Error if any output element differs from 3.0, or if loading or
        launching the kernel fails.
    """
    print("-")
    print("_run_cuda_external_function()")

    # Mojo signature mirroring the externally compiled kernel
    # `vectorAdd(const float*, const float*, float*, int)`, matching the
    # mangled name `_Z9vectorAddPKfS0_Pfi`. The last parameter is a 32-bit
    # C `int` (`.param .u32` in the PTX below). It is therefore declared as
    # `Int32`, not the 64-bit `Int`, so the launch ABI matches the kernel.
    fn vec_add_sig(
        in0: UnsafePointer[Float32, MutAnyOrigin],
        in1: UnsafePointer[Float32, MutAnyOrigin],
        output: UnsafePointer[Float32, MutAnyOrigin],
        len: Int32,
    ):
        pass

    var ptx = """
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-35059454
// Cuda compilation tools, release 12.6, V12.6.85
// Based on NVVM 7.0.1
//

.version 8.5
.target sm_80
.address_size 64

	// .globl	_Z9vectorAddPKfS0_Pfi

.visible .entry _Z9vectorAddPKfS0_Pfi(
	.param .u64 _Z9vectorAddPKfS0_Pfi_param_0,
	.param .u64 _Z9vectorAddPKfS0_Pfi_param_1,
	.param .u64 _Z9vectorAddPKfS0_Pfi_param_2,
	.param .u32 _Z9vectorAddPKfS0_Pfi_param_3
)
{
	.reg .pred 	%p<2>;
	.reg .f32 	%f<4>;
	.reg .b32 	%r<6>;
	.reg .b64 	%rd<11>;


	ld.param.u64 	%rd1, [_Z9vectorAddPKfS0_Pfi_param_0];
	ld.param.u64 	%rd2, [_Z9vectorAddPKfS0_Pfi_param_1];
	ld.param.u64 	%rd3, [_Z9vectorAddPKfS0_Pfi_param_2];
	ld.param.u32 	%r2, [_Z9vectorAddPKfS0_Pfi_param_3];
	mov.u32 	%r3, %ctaid.x;
	mov.u32 	%r4, %ntid.x;
	mov.u32 	%r5, %tid.x;
	mad.lo.s32 	%r1, %r3, %r4, %r5;
	setp.ge.s32 	%p1, %r1, %r2;
	@%p1 bra 	$L__BB0_2;

	cvta.to.global.u64 	%rd4, %rd1;
	mul.wide.s32 	%rd5, %r1, 4;
	add.s64 	%rd6, %rd4, %rd5;
	cvta.to.global.u64 	%rd7, %rd2;
	add.s64 	%rd8, %rd7, %rd5;
	ld.global.f32 	%f1, [%rd8];
	ld.global.f32 	%f2, [%rd6];
	add.f32 	%f3, %f2, %f1;
	cvta.to.global.u64 	%rd9, %rd3;
	add.s64 	%rd10, %rd9, %rd5;
	st.global.f32 	[%rd10], %f3;

$L__BB0_2:
	ret;

}
"""

    comptime LEN = 1024
    comptime BLOCK_DIM = 32
    # Round the block count up so every element is covered even when LEN is
    # not a multiple of BLOCK_DIM. The kernel's `setp.ge.s32` bounds check
    # makes any surplus threads exit without touching memory.
    comptime GRID_DIM = (LEN + BLOCK_DIM - 1) // BLOCK_DIM

    var lhs = ctx.enqueue_create_buffer[DType.float32](LEN)
    lhs.enqueue_fill(2.0)
    var rhs = ctx.enqueue_create_buffer[DType.float32](LEN)
    rhs.enqueue_fill(1.0)
    var out_dev = ctx.enqueue_create_buffer[DType.float32](LEN)

    var func = ctx.load_function[vec_add_sig](
        function_name="_Z9vectorAddPKfS0_Pfi",
        asm=ptx,
    )
    ctx.enqueue_function_checked(
        func,
        lhs,
        rhs,
        out_dev,
        Int32(LEN),
        grid_dim=Dim(GRID_DIM),
        block_dim=Dim(BLOCK_DIM),
    )

    # Give the host mapping its own name instead of shadowing the device
    # buffer it was created from.
    with out_dev.map_to_host() as out_host:
        for i in range(LEN):
            if i < 2:
                print("out[", i, "]: ", out_host[i])
            if out_host[i] != 3.0:
                raise Error("Bad value out[", i, "] is ", out_host[i])


def test_cuda_context():
    """Runs the CUcontext push/restore checks against a fresh test context."""
    _run_cuda_context(create_test_device_context())


def test_cuda_stream():
    """Runs the stream synchronization check against a fresh test context."""
    _run_cuda_stream(create_test_device_context())


def test_cuda_external_function():
    """Runs the external PTX kernel check against a fresh test context."""
    _run_cuda_external_function(create_test_device_context())


def test_cuda_multi_context():
    """Runs the nested multi-context checks on devices 0 and 1.

    Does nothing when `DeviceContext.number_of_devices()` reports fewer than
    two devices.
    """
    if DeviceContext.number_of_devices() <= 1:
        return
    var first_ctx = create_test_device_context()
    var second_ctx = create_test_device_context(device_id=1)
    _run_cuda_multi_context(first_ctx, second_ctx)


def main():
    # Entry point: builds a TestSuite from the functions in this module and
    # runs it. NOTE(review): presumably `discover_tests` selects the
    # `test_`-prefixed functions above; confirm against `TestSuite` docs.
    TestSuite.discover_tests[__functions_in_module()]().run()
